题解 P3702 【[SDOI2017]序列计数】

题意:Alice 想要得到一个长度为 $n$ 的序列,序列中的数都是不超过 $m$ 的正整数,而且这 $n$ 个数的和是 $p$ 的倍数。

Alice 还希望,这 $n$ 个数中,至少有一个数是质数。

Alice 想知道,有多少个序列满足她的要求。

$1\leq n \leq 10^9,1\leq m \leq 2\times 10^7,1\leq p\leq 100$。

很显然,要求的方案数可以转化为所有方案数减去不含质数的方案数

对于所有方案,设$f_{i,j}$表示前$i$个数$mod~p$等于$j$的方案数

$f_{i,j}=\sum\limits_{k=1}^{m}f_{i-1,((j-k)\%p+p)\%p}$

对于不含质数的方案,我们只需先预处理$cnt_i$表示$1$~$m$中$\%p==i$的合数的个数,设$g_{i,j}$表示前$i$个数$mod~p$等于$j$且没有质数的方案数。

$g_{i,j}=\sum\limits_{k=0}^{p-1} cnt_{((j-k)\%p+p)\%p}$

然后发现这两个东西显然都可以用矩阵快速幂优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
#define ll long long
#define mod 20170408
#define M 20000070
#define re register
using namespace std;
bool f[M];
int cnt[M],pre[7000000],tot,n,m,p,ans;
inline void add(int &a,int b){
a+=b;
while (a>=mod)a-=mod;
}
struct node{
int a[207][207];
inline void init0(){memset(a,0,sizeof(a));}
inline void init1(){
for (int i=0;i<=p-1;++i)
for (int j=0;j<=p-1;++j)
a[i][j]=i==j;
}
friend node operator *(node a,node b){
node res;res.init0();
for (int i=0;i<=p-1;++i)
for (int k=0;k<=p-1;++k)
for (int j=0;j<=p-1;++j)
add(res.a[i][k],1ll*a.a[i][j]*b.a[j][k]%mod);
return res;
}
friend node operator ^(node x,int p){
node res;res.init1();
while (p){
if (p&1)res=res*x;
x=x*x;
p>>=1;
}
return res;
}
}t,A;
inline int read(){
int x=0,w=0;char ch=getchar();
while (!isdigit(ch))w|=ch=='-',ch=getchar();
while (isdigit(ch))x=(x<<1)+(x<<3)+ch-'0',ch=getchar();
return w?-x:x;
}
void init(){
f[1]=1;
for (int i=2;i<=m;++i){
if (!f[i])pre[++tot]=i,--cnt[i%p];
for (int j=1;j<=tot&&i*pre[j]<=m;++j){
f[i*pre[j]]=1;
if (i%pre[j]==0)break;
}
}
}
signed main(){
// freopen("dodo.in", "r", stdin); freopen("dodo.out", "w", stdout);
n=read();m=read();p=read();
for (int i=1;i<=m;++i)++cnt[i%p];
init();t.init0();
for (re int i=0;i<=p-1;++i)
for (re int j=0;j<=p-1;++j){
int d=((j-i)%p+p)%p;
if (d==0)d=p;//特判,因为序列中的数都是不超过mm的正整数
if (m>=d)t.a[i][j]=(m-d)/p+1;
}
t=t^n;int res=t.a[0][0];
for (int i=0;i<=p-1;++i)
for (int j=0;j<=p-1;++j){
int d=((j-i)%p+p)%p;//余数
t.a[i][j]=cnt[d];
}
t=t^n;
cout<<((res-t.a[0][0])%mod+mod)%mod<<endl;
}